// <ExpressionGrammar> -*- C++ -*-

#include <cmath>

#include <boost/phoenix/core.hpp>
#include <boost/phoenix/operator.hpp>
#include <boost/phoenix/fusion.hpp>
#include <boost/phoenix/stl.hpp>
#include <boost/phoenix/object.hpp>

#include "sparta/statistics/ExpressionGrammar.hpp"
#include "sparta/statistics/Expression.hpp"
#include "sparta/statistics/ExpressionNodeVariables.hpp"
#include "sparta/simulation/TreeNode.hpp"
#include "sparta/simulation/Clock.hpp"
#include "sparta/utils/SpartaException.hpp"
#include "sparta/simulation/TreeNodePrivateAttorney.hpp"
#include "sparta/trigger/ContextCounterTrigger.hpp"
#include "sparta/statistics/StatInstCalculator.hpp"
#include <boost/math/constants/constants.hpp>

namespace phoenix = boost::phoenix;

namespace sparta {
    namespace statistics {
        namespace expression {
            namespace helpers {

//struct lazy_pow_
//{
//    template <typename X, typename E>
//    struct result {
//        typedef X type;
//    };
//
//    template <typename X, typename E>
//    X operator()(X x, E e) const {
//        return std::pow(x, e);
//    }
//}; // struct lazy_pow_

/*!
 * \brief Nullary functor adapter for boost phoenix to lazily evaluate a
 * function accepting no arguments.
 *
 * The callable handed to operator() is invoked with no arguments and its
 * result is returned as an Expression.
 * \note Always generates Expression instances
 */
struct lazy_nfunc_
{
    //! Phoenix/result_of protocol: primary template, only the specialization is used
    template<class Sig> struct result;

    template<typename Self, typename Fn>
    struct result<Self(Fn)>
    {
        using type = Expression;
    };

    /*!
     * \brief Invoke \a generator and return what it produced
     * \param generator Nullary callable producing an Expression
     */
    template <typename Fn>
    Expression operator()(Fn generator) const {
        return generator();
    }
}; // struct lazy_nfunc_

/*!
 * \brief Unary functor adapter for boost phoenix to lazily evaluate a function
 * accepting 1 argument.
 *
 * Calls \a fn with \a arg, writes the result back into \a arg and returns
 * \a arg by reference.
 */
struct lazy_ufunc_
{
    //! Phoenix/result_of protocol: primary template, only the specialization is used
    template<class Sig> struct result;

    template<typename Self, typename Fn, typename Arg>
    struct result<Self(Fn, Arg)>
    {
        using type = Arg&;
    };

    /*!
     * \brief Apply \a fn to \a arg in place
     * \param fn Unary callable
     * \param arg Argument; overwritten with fn(arg)
     * \return Reference to \a arg
     */
    template <typename Fn, typename Arg>
    Arg& operator()(Fn fn, Arg& arg) const {
        arg = fn(arg);
        return arg;
    }
}; // struct lazy_ufunc_

/*!
 * \brief Binary functor adapter for boost phoenix to lazily evaluate a function
 * accepting 2 arguments.
 *
 * Calls \a fn with (\a lhs, \a rhs), writes the result back into \a lhs and
 * returns \a lhs by reference.
 */
struct lazy_bfunc_
{
    //! Phoenix/result_of protocol: primary template, only the specialization is used
    template<class Sig> struct result;

    template<typename Self, typename Fn, typename Lhs, typename Rhs>
    struct result<Self(Fn, Lhs, Rhs)>
    {
        using type = Lhs&;
    };

    /*!
     * \brief Apply \a fn to (\a lhs, \a rhs), storing the result in \a lhs
     * \return Reference to \a lhs
     */
    template <typename Fn, typename Lhs, typename Rhs>
    Lhs& operator()(Fn fn, Lhs& lhs, Rhs& rhs) const {
        lhs = fn(lhs, rhs);
        return lhs;
    }
}; // struct lazy_bfunc_

/*!
 * \brief Ternary functor adapter for boost phoenix to lazily evaluate a
 * function accepting 3 arguments.
 *
 * Calls \a fn with (\a first, \a second, \a third), writes the result back
 * into \a first and returns \a first by reference.
 */
struct lazy_tfunc_
{
    //! Phoenix/result_of protocol: primary template, only the specialization is used
    template<class Sig> struct result;

    template<typename Self, typename Fn, typename First, typename Second, typename Third>
    struct result<Self(Fn, First, Second, Third)>
    {
        using type = First&;
    };

    /*!
     * \brief Apply \a fn to the three arguments, storing the result in \a first
     * \return Reference to \a first
     */
    template <typename Fn, typename First, typename Second, typename Third>
    First& operator()(Fn fn, First& first, Second& second, Third& third) const {
        first = fn(first, second, third);
        return first;
    }
}; // struct lazy_tfunc_

/*!
 * \brief Functor class for generating binary-function Expressions based on
 * C++ builtin functions (e.g. pow)
 *
 * \example
 * \code
 * lazy_builtin_bfunc_ lbbf("pow", &std::pow);
 * // Given: Expression a(2), b(3);
 * Expression c = lbbf(a,b); // note: a is also overwritten with the result
 * assert(c.evaluate() == 8);
 * \endcode
 */
struct lazy_builtin_bfunc_
{
    //! Signature of the wrapped builtin: double f(double, double)
    using fxn_t = double(*)(double, double);

    /*!
     * \brief Name given to each function node generated by this instance
     */
    std::string name_;

    /*!
     * \brief Pointer to the wrapped binary function (never null)
     */
    const fxn_t fxn_;

    /*!
     * \brief Construct the lazy binary function expression generator
     * \param name Name given to the generated function node when invoked
     * \param fxn Function attached to the function node when invoked. Must
     * not be null (asserted).
     */
    lazy_builtin_bfunc_(const std::string& name, fxn_t fxn) :
        name_(name), fxn_(fxn)
    {
        sparta_assert(fxn_ != nullptr);
    }

    //! Phoenix/result_of protocol: primary template, only the specialization is used
    template<class Sig> struct result;

    template<typename Self, typename Lhs, typename Rhs>
    struct result<Self(Lhs, Rhs)>
    {
        using type = Lhs&;
    };

    /*!
     * \brief Build a new binary-function Expression node applying fxn_ to
     * (\a lhs, \a rhs). The node replaces \a lhs, which is returned.
     */
    template <typename Lhs, typename Rhs>
    Lhs& operator()(Lhs& lhs, Rhs& rhs) const {
        lhs = Expression(name_, fxn_, lhs, rhs);
        return lhs;
    }
}; // struct lazy_builtin_bfunc_


/*!
 * \brief Functor class for generating StatVariable expression items
 */
class lazy_gen_var_
{
    TreeNode* n_; //!< Context node
    std::vector<const TreeNode*>& used_; //!< Disallowed nodes

public:

    /*!
     * \brief Constructor
     * \param n Context node in which to search for children
     * \param used Children already used in a parent expression. These nodes
     * must be rejected as they would create cycles if encountered
     */
    lazy_gen_var_(TreeNode* n, std::vector<const TreeNode*>& used) :
        n_(n),
        used_(used)
    { }

    template <typename A1>
    struct result { typedef Expression type; };

    template <typename A1>
    Expression operator()(A1 a1) const
    {
        if(n_ == nullptr){
            // Must construct with non-null to actually use
            throw SpartaException("No context TreeNode specified for expression parsing. Cannot "
                                "resolve \"") << a1 << "\" to a simulation object because there is "
                                "no context";
        }

        const TreeNode* n = TreeNodePrivateAttorney::getChild(n_, a1);

        // Check for cycles
        if(std::find(used_.begin(), used_.end(), n) != used_.end()){
            SpartaException ex("Cycle detected in a sparta statistic expression: ");
            for(auto& un : used_){
                ex << un->getLocation() << " -> ";
            }
            ex << '[' << n->getLocation() << ']';
            throw ex;
        }

        std::shared_ptr<StatInstCalculator> calculator;
        if(n == nullptr) {
            calculator = trigger::ContextCounterTrigger::findRegisteredContextCounterAggregateFcn(
                n_->getRoot(), a1);
            if(calculator) {
                n = calculator->getNode();
            }
        }

        if(n == nullptr) {
            sparta_assert(calculator == nullptr);
            SpartaException ex("While parsing the expression or term of expression: '");
            ex << a1 << "', SPARTA was unable to find a tree node that matched this name.";
            throw ex;
        }

        used_.push_back(n); // Add to used list

        StatVariable * sv = nullptr;
        if(calculator == nullptr) {
            sv = new StatVariable(n, used_); // Throws if cannot convert
        } else {
            sv = new StatVariable(calculator, used_);
        }

        used_.pop_back(); // Remove the stat so it can be used higher up or by
                          // other sibling expressions in the expression

        return Expression(sv);
    }
}; // class lazy_gen_var_

} // namespace helpers

namespace functions {

/*!
 * \brief Binary function if statement. If x is nan or inf (of either sign),
 * returns y. Otherwise returns x.
 * \note Uses the std:: classification functions explicitly: <cmath> does not
 * guarantee the unqualified C names in the global namespace, and they can be
 * ambiguous with some C library headers.
 */
static double ifnan(double x, double y)
{
    if(std::isnan(x) || std::isinf(x)){
        return y;
    }
    return x;
}

/*!
 * \brief Ternary function if statement.
 * \param cond Selector; any value comparing unequal to 0 (including NaN)
 * selects \a opt_nonzero, while +0.0 and -0.0 select \a opt_zero
 * \return opt_nonzero if cond is nonzero, otherwise opt_zero
 */
static double if_function(double cond, double opt_nonzero, double opt_zero)
{
    const bool take_nonzero_branch = (cond != 0);
    return take_nonzero_branch ? opt_nonzero : opt_zero;
}
            } // namespace functions

            namespace grammar {

/*!
 * \brief Populate the symbol table of named numeric constants usable in
 * expressions. Each entry maps a name to a constant-valued Expression.
 */
ExpressionGrammar::constants_::constants_()
{
    namespace mc = boost::math::constants;

    //! One named constant: symbol text and its numeric value
    struct NamedConstant {
        const char* name;
        double value;
    };

    const NamedConstant table[] = {
        {"c_pi",             mc::pi<double>()},
        {"c_root_pi",        mc::root_pi<double>()},
        {"c_root_half_pi",   mc::root_half_pi<double>()},
        {"c_root_two_pi",    mc::root_two_pi<double>()},
        {"c_root_ln_four",   mc::root_ln_four<double>()},
        {"c_e",              mc::e<double>()},
        {"c_half",           mc::half<double>()},
        {"c_euler",          mc::euler<double>()},
        {"c_root_two",       mc::root_two<double>()},
        {"c_ln_two",         mc::ln_two<double>()},
        {"c_ln_ln_two",      mc::ln_ln_two<double>()},
        {"c_third",          mc::third<double>()},
        {"c_twothirds",      mc::twothirds<double>()},
        {"c_pi_minus_three", mc::pi_minus_three<double>()},
        {"c_four_minus_pi",  mc::four_minus_pi<double>()},
        {"c_nan",            NAN},
        {"c_inf",            INFINITY},
    };

    for(const NamedConstant& entry : table){
        this->add(entry.name, Expression(entry.value));
    }
}

/*!
 * \brief Populate the table of builtin (nullary) variables.
 * \param n Context node whose Clock (and that Clock's scheduler) back the
 * variables. Must not be null (asserted).
 * \param used Used-node list forwarded to each generated StatVariable. It is
 * captured by reference, so it must outlive this table.
 *
 * Each entry is a generator that builds its Expression when invoked, not at
 * construction time. Generators throw SpartaException if \a n has no Clock.
 */
ExpressionGrammar::builtin_vars_::builtin_vars_(TreeNode* n,
                                                std::vector<const TreeNode*>& used)
{
    sparta_assert(nullptr != n,
                      "cannot construct ExpressionGrammar::builtin_vars_ with a null context");

    // Clock attached directly to a node. Throws instead of returning a dummy
    // because callers hold the result by reference.
    static auto get_clock_from_node = [](TreeNode* node) -> const Clock & {
        const Clock* clk = node->getClock();
        if (!clk) {
            throw SpartaException("Unable to determine clock from the context: ")
                << node->getLocation();
        }
        return *clk;
    };

    // Scheduler of the node's Clock; asserted non-null
    static auto get_scheduler_from_node = [](TreeNode* node) {
        auto scheduler = get_clock_from_node(node).getScheduler();
        sparta_assert(scheduler);
        return scheduler;
    };

    // Global Variables
    this->add
        ("g_ticks" ,       [n, &used]() -> Expression {
            auto scheduler = get_scheduler_from_node(n);
            return Expression(new StatVariable(&scheduler->getCurrentTicksROCounter(), used));
        })
        ("g_seconds",      [n, &used]() -> Expression {
            auto scheduler = get_scheduler_from_node(n);
            return Expression(new StatVariable(&scheduler->getSecondsStatisticDef(), used));
        })
        ("g_milliseconds", [n, &used]() -> Expression {
            auto scheduler = get_scheduler_from_node(n);
            return Expression(new StatVariable(&scheduler->getCurrentMillisecondsStatisticDef(), used));
        })
        ("g_microseconds", [n, &used]() -> Expression {
            auto scheduler = get_scheduler_from_node(n);
            return Expression(new StatVariable(&scheduler->getCurrentMicrosecondsStatisticDef(), used));
        })
        ("g_nanoseconds",  [n, &used]() -> Expression {
            auto scheduler = get_scheduler_from_node(n);
            return Expression(new StatVariable(&scheduler->getCurrentNanosecondsStatisticDef(), used));
        })
        ("g_picoseconds",  [n, &used]() -> Expression {
            auto scheduler = get_scheduler_from_node(n);
            return Expression(new StatVariable(&scheduler->getCurrentPicosecondsROCounter(), used));
        })
        ;

    // Local Variables
    this->add
        ("cycles", [n, &used]() -> Expression {
            const Clock & clk = get_clock_from_node(n);
            //! \todo Use this node instead of a statvariable ?
            return Expression(new StatVariable(&clk.getCyclesROCounter(), used));
        })
        ;
    this->add
        ("freq_mhz", [n]() -> Expression {
            const Clock & clk = get_clock_from_node(n);
            return Expression(clk.getFrequencyMhz());
        })
        ;
}

/*!
 * \brief Grammar rule matching a TreeNode name and turning it into a
 * StatVariable Expression.
 * \param n Context node used to resolve names. Must not be null (asserted).
 * \param used Used-node list for cycle detection, held by reference by the
 * variable factory.
 */
ExpressionGrammar::variable_::variable_(sparta::TreeNode* n,
                                        std::vector<const TreeNode*>& used) :
    variable_::base_type(start)
{
    sparta_assert(nullptr != n,
                      "cannot construct ExpressionGrammar::variable_ with a null context");

    using qi::ascii::char_;
    using qi::_val;

    // Variable factory
    helpers::lazy_gen_var_ lgv(n, used);
    phoenix::function<helpers::lazy_gen_var_> lazy_gen_var(lgv);

    start = str [_val = lazy_gen_var(qi::_1)];

    // TreeNode names: all alphanumeric characters plus '_', '.', '[' and ']'.
    // Spirit char-set definitions only give '-' special meaning (ranges);
    // there is no backslash escaping. Writing "\\." would therefore also
    // admit the '\' character.
    str %= +(char_("a-zA-Z0-9_.[]"));
}

/*!
 * \brief Populate the table of unary functions usable as name(expr).
 * \param used Used-node list captured by reference by the parameterized
 * "cycles" entry. It must outlive this table.
 *
 * Builtin function pointers are selected with static_cast (not C-style
 * casts), so an overload or signature mismatch is rejected at compile time
 * instead of possibly degrading into a reinterpret_cast.
 */
ExpressionGrammar::ufunc_::ufunc_(std::vector<const TreeNode*>& used)
{
    // Math Utils
    this->add
        ("abs"  ,   [](Expression& a) -> Expression {return Expression("abs",     static_cast<fptr_dd_t>(&std::fabs), a);}) // Behave like fabs
        ("fabs" ,   [](Expression& a) -> Expression {return Expression("fabs",    static_cast<fptr_dd_t>(&std::fabs), a);})
        ("acos" ,   [](Expression& a) -> Expression {return Expression("acos",    static_cast<fptr_dd_t>(&std::acos), a);})
        ("asin" ,   [](Expression& a) -> Expression {return Expression("asin",    static_cast<fptr_dd_t>(&std::asin), a);})
        ("atan" ,   [](Expression& a) -> Expression {return Expression("atan",    static_cast<fptr_dd_t>(&std::atan), a);})
        ("ceil" ,   [](Expression& a) -> Expression {return Expression("ceil",    static_cast<fptr_dd_t>(&std::ceil), a);})
        ("trunc",   [](Expression& a) -> Expression {return Expression("trunc",   static_cast<fptr_dd_t>(&std::trunc), a);})
        ("round",   [](Expression& a) -> Expression {return Expression("round",   static_cast<fptr_dd_t>(&std::round), a);})
        ("cos"  ,   [](Expression& a) -> Expression {return Expression("cos",     static_cast<fptr_dd_t>(&std::cos), a);})
        ("cosh" ,   [](Expression& a) -> Expression {return Expression("cosh",    static_cast<fptr_dd_t>(&std::cosh), a);})
        ("exp"  ,   [](Expression& a) -> Expression {return Expression("exp",     static_cast<fptr_dd_t>(&std::exp), a);})
        ("exp2" ,   [](Expression& a) -> Expression {return Expression("exp2",    static_cast<fptr_dd_t>(&std::exp2), a);})
        ("floor",   [](Expression& a) -> Expression {return Expression("floor",   static_cast<fptr_dd_t>(&std::floor), a);})
        ("ln"   ,   [](Expression& a) -> Expression {return Expression("ln",      static_cast<fptr_dd_t>(&std::log), a);})
        ("log2" ,   [](Expression& a) -> Expression {return Expression("log2",    static_cast<fptr_dd_t>(&std::log2), a);})
        ("log10",   [](Expression& a) -> Expression {return Expression("log10",   static_cast<fptr_dd_t>(&std::log10), a);})
        ("sin"  ,   [](Expression& a) -> Expression {return Expression("sin",     static_cast<fptr_dd_t>(&std::sin), a);})
        ("sinh" ,   [](Expression& a) -> Expression {return Expression("sinh",    static_cast<fptr_dd_t>(&std::sinh), a);})
        ("sqrt" ,   [](Expression& a) -> Expression {return Expression("sqrt",    static_cast<fptr_dd_t>(&std::sqrt), a);})
        ("cbrt" ,   [](Expression& a) -> Expression {return Expression("cbrt",    static_cast<fptr_dd_t>(&std::cbrt), a);})
        ("tan"  ,   [](Expression& a) -> Expression {return Expression("tan",     static_cast<fptr_dd_t>(&std::tan), a);})
        ("tanh" ,   [](Expression& a) -> Expression {return Expression("tanh",    static_cast<fptr_dd_t>(&std::tanh), a);})
        ("isnan",   [](Expression& a) -> Expression {return Expression("isnan",   static_cast<fptr_bd_t>(&std::isnan), a);})
        ("isinf",   [](Expression& a) -> Expression {return Expression("isinf",   static_cast<fptr_bd_t>(&std::isinf), a);})
        ("signbit", [](Expression& a) -> Expression {return Expression("signbit", static_cast<fptr_bd_t>(&std::signbit), a);})
        ("logb",    [](Expression& a) -> Expression {return Expression("logb",    static_cast<fptr_dd_t>(&std::logb), a);})
        ("erf",     [](Expression& a) -> Expression {return Expression("erf",     static_cast<fptr_dd_t>(&std::erf), a);})
        ("erfc",    [](Expression& a) -> Expression {return Expression("erfc",    static_cast<fptr_dd_t>(&std::erfc), a);})
        ("lgamma",  [](Expression& a) -> Expression {return Expression("lgamma",  static_cast<fptr_dd_t>(&std::lgamma), a);})
        ("tgamma",  [](Expression& a) -> Expression {return Expression("tgamma",  static_cast<fptr_dd_t>(&std::tgamma), a);})
        ;

    // Parameterized Variables
    this->add
        ("cycles", [&used](Expression& a) -> Expression {
            const Clock* clk = a.getClock();
            if(!clk){
                throw SpartaException("Unable to determine a Clock from the expression: ") << a;
            }
            //! \todo Use this node instead of a statvariable ?
            return Expression(new StatVariable(&clk->getCyclesROCounter(), used));
        })
        ;

}

/*!
 * \brief Populate the table of binary functions usable as name(a, b).
 * \param used Currently unused.
 *
 * Function pointers are selected with static_cast (not C-style casts), so an
 * overload or signature mismatch is rejected at compile time instead of
 * possibly degrading into a reinterpret_cast. Comparison and logical entries
 * use std:: function objects on doubles.
 */
ExpressionGrammar::bfunc_::bfunc_(const std::vector<const TreeNode*>& used)
{
    (void) used;

    this->add
        ("pow"  ,
         [](Expression& a, Expression& b) -> Expression {
            return Expression("pow", static_cast<fptr_ddd_t>(&std::pow), a, b);
         })
        ("atan2",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("atan2", static_cast<fptr_ddd_t>(&std::atan2), a, b);
         })
        ("min",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("min", static_cast<fptr_drdrdrt>(&std::min<double>), a, b);
         })
        ("max",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("max", static_cast<fptr_drdrdrt>(&std::max<double>), a, b);
         })
        ("ifnan",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("ifnan", static_cast<fptr_ddd_t>(&functions::ifnan), a, b);
         })
        ("fmod",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("fmod", static_cast<fptr_ddd_t>(&std::fmod), a, b);
         })
        ("remainder",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("remainder", static_cast<fptr_ddd_t>(&std::remainder), a, b);
         })
        ("hypot",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("hypot", static_cast<fptr_ddd_t>(&std::hypot), a, b);
         })
        ("is_greater",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("is_greater", std::greater<double>(), a, b);
         })
        ("is_lesser",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("is_lesser", std::less<double>(), a, b);
         })
        ("is_equal",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("is_equal", std::equal_to<double>(), a, b);
         })
        ("is_not_equal",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("is_not_equal", std::not_equal_to<double>(), a, b);
         })
        ("is_greater_equal",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("is_greater_equal", std::greater_equal<double>(), a, b);
         })
        ("is_lesser_equal",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("is_lesser_equal", std::less_equal<double>(), a, b);
         })
        ("logical_and",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("logical_and", std::logical_and<double>(), a, b);
         })
        ("logical_or",
         [](Expression& a, Expression& b) -> Expression {
            return Expression("logical_or", std::logical_or<double>(), a, b);
         })
        ;
}

/*!
 * \brief Populate the table of ternary functions usable as name(a, b, c).
 * \param used Currently unused.
 *
 * "cond(c, x, y)" yields x when c is nonzero and y otherwise. The function
 * pointer is selected with static_cast, so a signature mismatch is a compile
 * error rather than a silent C-style reinterpretation.
 */
ExpressionGrammar::tfunc_::tfunc_(const std::vector<const TreeNode*>& used)
{
    (void) used;

    this->add
        ("cond",
         [](Expression& a, Expression& b, Expression& c) -> Expression {
            return Expression("cond", static_cast<fptr_dddd_t>(&functions::if_function), a, b, c);
         })
        ;
}

/*!
 * \brief Build the arithmetic expression grammar.
 * \param root Context node used to resolve variables. Must not be null
 * (asserted).
 * \param used Used-node list for cycle detection, shared by reference with
 * the sub-grammars.
 *
 * Precedence, lowest to highest:
 * - '+' / '-' (left-associative)
 * - '*' / '/' (left-associative)
 * - '**' (right-associative, via recursion on factor)
 * - primary terms
 *
 * Unary '+'/'-' live in primary, so they bind tighter than '**'
 * (e.g. -2**2 == 4).
 */
ExpressionGrammar::ExpressionGrammar(sparta::TreeNode* root,
                                     std::vector<const TreeNode*>& used) :
    ExpressionGrammar::base_type(expression),
    builtin_vars(root, used),
    ufunc(used),
    bfunc(used),
    tfunc(used),
    var(root, used),
    root_(root)
{
    (void) root_;
    sparta_assert(nullptr != root,
                      "cannot construct ExpressionGrammar with a null context");

    using qi::real_parser;
    using qi::real_policies;
    using qi::no_case;
    using qi::_val;
    using qi::ascii::char_;
    real_parser<double,real_policies<double> > real;

    helpers::lazy_builtin_bfunc_ lzbbf_pow("pow", &std::pow);
    boost::phoenix::function<helpers::lazy_builtin_bfunc_> lazy_pow(lzbbf_pow);
    boost::phoenix::function<helpers::lazy_nfunc_> lazy_nfunc;
    boost::phoenix::function<helpers::lazy_ufunc_> lazy_ufunc;
    boost::phoenix::function<helpers::lazy_bfunc_> lazy_bfunc;
    boost::phoenix::function<helpers::lazy_tfunc_> lazy_tfunc;

    expression =
        term                       [_val =  qi::_1]
        >> *(  ('+' >> term        [_val += qi::_1])
            |  ('-' >> term        [_val -= qi::_1])
            )
        ;

    term =
        factor                     [_val =  qi::_1]
        >> *(  ('*' >> factor      [_val *= qi::_1])
            |  ('/' >> factor      [_val /= qi::_1])
            )
        ;

    factor =
        primary                    [_val =  qi::_1]
        >> *(  ("**" >> factor     [_val = lazy_pow(_val, qi::_1)])
            )
        ;

    // Real literals, constants and builtin variables are matched by prefix.
    // Without a boundary check, a TreeNode named e.g. "inflight" would be
    // consumed as the literal "inf" (real_policies accepts "inf"/"nan"), and
    // "cycles_total" as the builtin "cycles", leaving unparseable input
    // behind. The negative lookahead below (same character set as TreeNode
    // names) makes those alternatives fail so parsing falls through to var.
    primary =
            (real >> !char_("a-zA-Z0-9_.[]"))
                                   [_val =  qi::_1]
       |    '(' >> expression      [_val =  qi::_1] >> ')'
       |    ('-' >> primary        [_val = -qi::_1])
       |    ('+' >> primary        [_val =  qi::_1])
       |    (no_case[ufunc] >> '(' >> expression >> ')')
                                   [_val = lazy_ufunc(qi::_1, qi::_2)]
       |    (no_case[bfunc] >> '(' >> expression >> ','
                                   >> expression >> ')')
                                   [_val = lazy_bfunc(qi::_1, qi::_2, qi::_3)]
       |    (no_case[tfunc] >> '(' >> expression >> ','
                                   >> expression >> ','
                                   >> expression >> ')')
                                   [_val = lazy_tfunc(qi::_1, qi::_2, qi::_3, qi::_4)]
       |    (no_case[constants] >> !char_("a-zA-Z0-9_.[]"))
                                   [_val =  qi::_1]
       |    (no_case[builtin_vars] >> !char_("a-zA-Z0-9_.[]"))
                                   [_val = lazy_nfunc(qi::_1)]
       |    no_case[var]           [_val = qi::_1]
        ;

    expression.name("expression");
    term.name("term");
    factor.name("factor");
    primary.name("primary");

    qi::on_error<qi::fail>
    (
        expression
      , std::cout
            << phoenix::val("Error! Expecting ")
            << qi::_4                               // what failed?
            << phoenix::val(" here: \"")
            << phoenix::construct<std::string>(qi::_3, qi::_2)   // iterators to error-pos, end
            << phoenix::val("\"")
            << std::endl
    );

    // Debugging switches
    //qi::debug(expression);
    //qi::debug(term);
    //qi::debug(factor);
    //qi::debug(primary);
}

            } // namespace grammar
        } // namespace expression
    } // namespace statistics
} // namespace sparta
